{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the WikiQA train/dev/test splits as ranking-task data packs in one statement.\n",
    "# The dev and test loads pass filter=True; the train load does not.\n",
    "train_pack, valid_pack, predict_pack = (\n",
    "    mz.datasets.wiki_qa.load_data('train', task='ranking'),\n",
    "    mz.datasets.wiki_qa.load_data('dev', task='ranking', filter=True),\n",
    "    mz.datasets.wiki_qa.load_data('test', task='ranking', filter=True),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit: 100%|██████████| 2118/2118 [00:00<00:00, 6751.07it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit: 100%|██████████| 18841/18841 [00:05<00:00, 3730.90it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 478072.11it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 279964.01it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 3407400.32it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8756.29it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 4251.64it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 122800.84it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 59226.17it/s]\n",
      "Processing text_left with chain_transform of NgramLetterUnit => WordHashingUnit: 100%|██████████| 2118/2118 [00:00<00:00, 5882.98it/s]\n",
      "Processing text_right with chain_transform of NgramLetterUnit => WordHashingUnit: 100%|██████████| 18841/18841 [00:08<00:00, 2111.86it/s]\n"
     ]
    }
   ],
   "source": [
    "# Create the CDSSM preprocessor, fit it on the training pack and transform that pack in the same call.\n",
    "# The same fitted preprocessor instance is reused below for the dev/test packs and for the model input shapes.\n",
    "preprocessor = mz.preprocessors.CDSSMPreprocessor()\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 8118.95it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 4421.77it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 82961.27it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 63304.89it/s]\n",
      "Processing text_left with chain_transform of NgramLetterUnit => WordHashingUnit: 100%|██████████| 122/122 [00:00<00:00, 6042.02it/s]\n",
      "Processing text_right with chain_transform of NgramLetterUnit => WordHashingUnit: 100%|██████████| 1115/1115 [00:00<00:00, 2165.76it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 4295.23it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 4193.45it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 94026.68it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 70195.00it/s]\n",
      "Processing text_left with chain_transform of NgramLetterUnit => WordHashingUnit: 100%|██████████| 237/237 [00:00<00:00, 6278.78it/s]\n",
      "Processing text_right with chain_transform of NgramLetterUnit => WordHashingUnit: 100%|██████████| 2300/2300 [00:01<00:00, 2156.54it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run the preprocessor fitted on the training pack over the dev and test packs, in that order.\n",
    "valid_pack_processed, predict_pack_processed = map(\n",
    "    preprocessor.transform, (valid_pack, predict_pack)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Ranking task using a rank hinge loss, reported with NDCG at cutoffs 3 and 5 plus mean average precision.\n",
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n",
    "ranking_task.metrics = [\n",
    "    *(mz.metrics.NormalizedDiscountedCumulativeGain(k=cutoff) for cutoff in (3, 5)),\n",
    "    mz.metrics.MeanAveragePrecision(),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to CDSSMModel.\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10, 9644)     0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 40, 9644)     0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_11 (Conv1D)              (None, 10, 300)      8679900     text_left[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_12 (Conv1D)              (None, 40, 300)      8679900     text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dropout_11 (Dropout)            (None, 10, 300)      0           conv1d_11[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dropout_12 (Dropout)            (None, 40, 300)      0           conv1d_12[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "global_max_pooling1d_11 (Global (None, 300)          0           dropout_11[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "global_max_pooling1d_12 (Global (None, 300)          0           dropout_12[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dense_22 (Dense)                (None, 300)          90300       global_max_pooling1d_11[0][0]    \n",
      "__________________________________________________________________________________________________\n",
      "dense_24 (Dense)                (None, 300)          90300       global_max_pooling1d_12[0][0]    \n",
      "__________________________________________________________________________________________________\n",
      "dense_23 (Dense)                (None, 128)          38528       dense_22[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "dense_25 (Dense)                (None, 128)          38528       dense_24[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "dot_6 (Dot)                     (None, 1)            0           dense_23[0][0]                   \n",
      "                                                                 dense_25[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "dense_26 (Dense)                (None, 1)            2           dot_6[0][0]                      \n",
      "==================================================================================================\n",
      "Total params: 17,617,458\n",
      "Trainable params: 17,617,458\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.CDSSMModel()\n",
    "\n",
    "# Hyper-parameters for the CDSSM model, applied to model.params in the listed order.\n",
    "# input_shapes is read from the fitted preprocessor; task is the ranking task defined in an earlier cell.\n",
    "cdssm_settings = [\n",
    "    ('input_shapes', preprocessor.context['input_shapes']),\n",
    "    ('task', ranking_task),\n",
    "    ('filters', 300),\n",
    "    ('kernel_size', 3),\n",
    "    ('strides', 1),\n",
    "    ('padding', 'same'),\n",
    "    ('conv_activation_func', 'tanh'),\n",
    "    ('w_initializer', 'glorot_normal'),\n",
    "    ('b_initializer', 'zeros'),\n",
    "    ('mlp_num_layers', 1),\n",
    "    ('mlp_num_units', 300),\n",
    "    ('mlp_num_fan_out', 128),\n",
    "    ('mlp_activation_func', 'tanh'),\n",
    "    ('dropout_rate', 0.8),\n",
    "    ('optimizer', 'adadelta'),\n",
    "]\n",
    "for param_name, param_value in cdssm_settings:\n",
    "    model.params[param_name] = param_value\n",
    "\n",
    "# Fill any parameters still unset, then build and compile the model and print the backend summary.\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Unpack the whole processed test pack into model inputs and labels for the evaluation callback.\n",
    "# NOTE(review): this scores the test split during training while valid_pack_processed is never used;\n",
    "# confirm whether the dev split was intended here.\n",
    "pred_x, pred_y = predict_pack_processed[:].unpack()\n",
    "\n",
    "# Size the evaluation batch from the labels, not the inputs: the model has two named inputs\n",
    "# (text_left, text_right), so pred_x is presumably a container of input arrays and len(pred_x)\n",
    "# would count those arrays instead of the examples. If one full-size batch does not fit in\n",
    "# memory, replace len(pred_y) with a fixed batch size.\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))\n",
    "\n",
    "# Shuffled pair-wise training batches of size 64 built from the processed training pack.\n",
    "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=2, num_neg=1, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "32/32 [==============================] - 24s 748ms/step - loss: 0.7844\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.483600 - normalized_discounted_cumulative_gain@5(0):0.553078 - mean_average_precision(0):0.512808\n",
      "Epoch 2/20\n",
      "32/32 [==============================] - 29s 898ms/step - loss: 0.5235\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.488520 - normalized_discounted_cumulative_gain@5(0):0.561723 - mean_average_precision(0):0.510881\n",
      "Epoch 3/20\n",
      "32/32 [==============================] - 23s 706ms/step - loss: 0.2984\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.484800 - normalized_discounted_cumulative_gain@5(0):0.555596 - mean_average_precision(0):0.512422\n",
      "Epoch 4/20\n",
      "32/32 [==============================] - 25s 775ms/step - loss: 0.1954\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.475072 - normalized_discounted_cumulative_gain@5(0):0.562940 - mean_average_precision(0):0.512383\n",
      "Epoch 5/20\n",
      "32/32 [==============================] - 23s 721ms/step - loss: 0.1056\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.487088 - normalized_discounted_cumulative_gain@5(0):0.566688 - mean_average_precision(0):0.519216\n",
      "Epoch 6/20\n",
      "32/32 [==============================] - 23s 706ms/step - loss: 0.0727\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.478561 - normalized_discounted_cumulative_gain@5(0):0.562530 - mean_average_precision(0):0.512396\n",
      "Epoch 7/20\n",
      "32/32 [==============================] - 28s 884ms/step - loss: 0.0505\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.486433 - normalized_discounted_cumulative_gain@5(0):0.567996 - mean_average_precision(0):0.522548\n",
      "Epoch 8/20\n",
      "32/32 [==============================] - 26s 822ms/step - loss: 0.0544\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.498047 - normalized_discounted_cumulative_gain@5(0):0.575753 - mean_average_precision(0):0.526914\n",
      "Epoch 9/20\n",
      "32/32 [==============================] - 22s 700ms/step - loss: 0.0386\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.495465 - normalized_discounted_cumulative_gain@5(0):0.574323 - mean_average_precision(0):0.523699\n",
      "Epoch 10/20\n",
      "32/32 [==============================] - 24s 748ms/step - loss: 0.0290\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.502232 - normalized_discounted_cumulative_gain@5(0):0.581065 - mean_average_precision(0):0.527918\n",
      "Epoch 11/20\n",
      "32/32 [==============================] - 28s 867ms/step - loss: 0.0297\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.496569 - normalized_discounted_cumulative_gain@5(0):0.576849 - mean_average_precision(0):0.521037\n",
      "Epoch 12/20\n",
      "32/32 [==============================] - 24s 735ms/step - loss: 0.0212\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.511699 - normalized_discounted_cumulative_gain@5(0):0.589066 - mean_average_precision(0):0.532874\n",
      "Epoch 13/20\n",
      "32/32 [==============================] - 23s 705ms/step - loss: 0.0197\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.491434 - normalized_discounted_cumulative_gain@5(0):0.570174 - mean_average_precision(0):0.513504\n",
      "Epoch 14/20\n",
      "32/32 [==============================] - 28s 869ms/step - loss: 0.0132\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.504107 - normalized_discounted_cumulative_gain@5(0):0.575971 - mean_average_precision(0):0.519329\n",
      "Epoch 15/20\n",
      "32/32 [==============================] - 26s 826ms/step - loss: 0.0150\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.503686 - normalized_discounted_cumulative_gain@5(0):0.580463 - mean_average_precision(0):0.524712\n",
      "Epoch 16/20\n",
      "32/32 [==============================] - 25s 778ms/step - loss: 0.0123\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.499589 - normalized_discounted_cumulative_gain@5(0):0.576160 - mean_average_precision(0):0.516276\n",
      "Epoch 17/20\n",
      "32/32 [==============================] - 22s 699ms/step - loss: 0.0165\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.501098 - normalized_discounted_cumulative_gain@5(0):0.577585 - mean_average_precision(0):0.525730\n",
      "Epoch 18/20\n",
      "32/32 [==============================] - 28s 867ms/step - loss: 0.0089\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.502192 - normalized_discounted_cumulative_gain@5(0):0.575193 - mean_average_precision(0):0.527620\n",
      "Epoch 19/20\n",
      "32/32 [==============================] - 23s 722ms/step - loss: 0.0074\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.508068 - normalized_discounted_cumulative_gain@5(0):0.580008 - mean_average_precision(0):0.534251\n",
      "Epoch 20/20\n",
      "32/32 [==============================] - 28s 881ms/step - loss: 0.0098\n",
      "Validation: loss:nan - normalized_discounted_cumulative_gain@3(0):0.504778 - normalized_discounted_cumulative_gain@5(0):0.572644 - mean_average_precision(0):0.531077\n"
     ]
    }
   ],
   "source": [
    "# Train for 20 epochs from the pair generator with a single worker and no multiprocessing;\n",
    "# the evaluate callback prints the ranking metrics after every epoch.\n",
    "history = model.fit_generator(\n",
    "    train_generator,\n",
    "    epochs=20,\n",
    "    callbacks=[evaluate],\n",
    "    workers=1,\n",
    "    use_multiprocessing=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
